import torch
import copy
import numpy as np
import time
import scipy.stats as stats 
import math
import safetensors

# Checkpoint tensors, loaded onto the CPU.
weight = torch.load("./llama/org_models/7B/consolidated.00.pth", map_location='cpu')
# Module-level state shared with the estimator functions: E_exp2 is read by
# E_fij_2, and E_fi_2_dict memoises E_fi_2 results. The driver reassigns both.
E_exp2, E_fi_2_dict = 0, dict()

def Q(value,mu,sigma):
    """Return the upper-tail probability P(X > value) for X ~ N(mu, sigma**2).

    Uses the survival function ``stats.norm.sf`` rather than ``1 - cdf``.
    The subtraction cancels catastrophically in the upper tail: it returns
    exactly 0.0 once ``cdf`` rounds to 1, which happens roughly 8 standard
    deviations above the mean. ``sf`` keeps full relative precision there.
    Accepts scalars or arrays, like ``stats.norm.cdf``.
    """
    return stats.norm.sf(value,mu,sigma)

def B_ij(macc,mp,i,j):
    """Return P(lo < |Z| <= 2*lo) for Z ~ N(0, 1), computed as 2*(Q(lo) - Q(2*lo)).

    Here lo = 2**(macc+1+mp-j) / sqrt(2*pi*i).
    Presumably this is the model's probability that j mantissa bits of the
    i-th addend are shifted out; that interpretation is unverified.
    """
    root = (2*i*torch.pi)**(1/2)
    shift = macc + 1 + mp - j
    lower_tail = Q(2**shift/root,0,1)
    upper_tail = Q(2**(shift + 1)/root,0,1)
    return 2*(lower_tail - upper_tail)

def C(n):
    """Return prod_{k=1}^{n-1} (1 - 2*Q(2**(macc+1)/sqrt(2*pi*k), 0, 1)).

    Returns 1.0 (the empty product) when n <= 1.

    NOTE(review): ``macc`` is not a parameter. It is read from module scope,
    where the driver's ``for macc in ...`` loop defines it. Calling C with
    n > 1 before that loop has run raises NameError. C is not called in the
    visible code.
    """
    result = 1.
    k = 1
    while k < n:
        threshold = 2**(macc+1)/((2*k*torch.pi)**(1/2))
        result *= 1 - 2*Q(threshold,0,1)
        k += 1
    return result

def E_i_2exp(sigma):
    """Return sum_{e=-50}^{49} P(2**e < |X| <= 2**(e+1)) * 2**e for X ~ N(0, sigma**2).

    This is the expected value of the power-of-two bucket of |X|
    (2**floor(log2|X|)), restricted to exponents -50..49.
    Assumes sigma > 0.
    """
    acc = 0
    for lo_exp in range(-50, 50):
        bin_lo = 2**lo_exp
        bin_hi = 2**(lo_exp + 1)
        # Probability mass of |X| in (bin_lo, bin_hi], from both tails.
        bin_mass = 2*(Q(bin_lo/sigma,0,1) - Q(bin_hi/sigma,0,1))
        acc += bin_mass * bin_lo
    return acc

def E_fij_2(j,mp):
    """Return 2**(-2*mp) * (2**j - 1) * (2**(j+1) - 1) * E_exp2 / 6.

    (2**j - 1)*(2**(j+1) - 1)/6 is the mean of k**2 for k uniform on
    {0, ..., 2**j - 1}, i.e. the mean square of a uniformly random j-bit
    integer. It is scaled by 2**(-2*mp) and by the module-level ``E_exp2``,
    which is only read here.
    """
    scale = 2**(-2*mp)
    return scale*(2**j - 1)*(2**(j+1) - 1)*E_exp2/6

def E_fi_2(i,mp,macc):
    """Return sum_{j=1}^{mp} B_ij(macc, mp, i, j) * E_fij_2(j, mp).

    Returns 0 when mp < 1. Presumably this is the expected squared truncation
    error of the i-th addend: the probability that j bits are lost (B_ij)
    times the mean square of j lost bits (E_fij_2). That reading is
    unverified.
    """
    expected = 0
    j = 1
    while j <= mp:
        expected += B_ij(macc,mp,i,j) * E_fij_2(j,mp)
        j += 1
    return expected

def E_Sn2_swamping_i(i,mp,macc,sigma):
    """Return i*sigma**2 - sum_{q=1}^{i} E_fi_2(q, mp, macc).

    E_fi_2 results are memoised in the module-level dict ``E_fi_2_dict``.
    E_fi_2 depends on ``q``, ``mp``, ``macc`` and, through E_fij_2, on the
    module-level ``E_exp2``, so all four form the cache key. A key of ``q``
    alone would let a later call with a different ``mp`` or ``E_exp2`` reuse
    stale values. The driver does exactly that between its two stages.

    Args:
        i: number of steps summed (q = 1..i); returns i*sigma**2 when i < 1.
        mp, macc: forwarded to E_fi_2.
        sigma: only used in the i*sigma**2 term.
    """
    global E_fi_2_dict
    # float() so numpy/torch scalars hash by value, not by identity or not at all.
    exp2_key = float(E_exp2)
    total = 0
    for q in range(1, i+1):
        key = (q, mp, macc, exp2_key)
        e = E_fi_2_dict.get(key)
        if e is None:
            e = E_fi_2_dict[key] = E_fi_2(q,mp,macc)
        total += e
    return i*(sigma**2) - total

def cal_Sn2_swamp(n, mp, macc, sigma):
    """Return the swamping-weighted expected squared sum of n addends.

    Let p(k) = 2*Q(2**(macc+1)/sqrt(2*pi*k), 0, 1) be the probability that
    step k swamps. Each step i gets the weight

        w_i = p(i) * prod_{k<i} (1 - p(k))    for 1 <= i < n
        w_n = prod_{k<n} (1 - p(k))           (no swamping before the last step)

    and the result is sum_i w_i * E_Sn2_swamping_i(i, mp, macc, sigma).
    The weights telescope to 1. Steps whose weight is exactly 0 are skipped.

    Args:
        n: number of addends. Returns 0. when n < 1.
        mp: forwarded to E_Sn2_swamping_i.
        macc: width used in the swamping threshold 2**(macc+1). Presumably the
            accumulator mantissa width; confirm.
        sigma: forwarded to E_Sn2_swamping_i.
    """
    total = 0.
    # Running prod_{k<i} (1 - p(k)); starts as the empty product.
    no_swamp_yet = 1.
    for i in range(1, n+1):
        if i == n:
            # Final step: weight is the probability that nothing swamped earlier.
            # For n == 1 this is the empty product 1, so the weights still sum to 1.
            weight = no_swamp_yet
        else:
            swamp_prob = 2*Q(2**(macc+1)/((2*i*torch.pi)**(1/2)),0,1)
            weight = swamp_prob * no_swamp_yet
            no_swamp_yet *= 1 - swamp_prob
        if weight != 0:
            total += weight*E_Sn2_swamping_i(i,mp,macc,sigma)
    return total

# Mean squared weight over every checkpoint tensor whose key contains 'weight'
# but not 'norm'. The squares are summed in float64.
weight_2_sum = 0.
counts = 0
for key in weight.keys():
    if 'weight' in key and 'norm' not in key:
        weight_ = weight[key].double()
        weight_2_sum += torch.sum(weight_**2).item()
        # numel() counts every element without assuming the tensor is 2-D.
        counts += weight_.numel()

K = 128            # each of the n // K addends is modelled with variance K*Dstep
mp = 8             # passed as `mp` for the addends (presumably their mantissa width)
n = 4096           # total number of products
segment_label = 1  # 0: accumulate all n // K addends at once; 1: two-level accumulation

weight_2_mean = weight_2_sum / counts
Dstep = weight_2_mean
sigma = (K*Dstep)**(1/2)
E_exp2 = E_i_2exp(sigma)**2

Sn2_ideal = n*(sigma**(2))

n_fp = n // K
k1 = int(n_fp ** (1/2))
# E_exp2 for the addends. It is restored at the top of every macc iteration
# because the two-level branch overwrites it for its second stage. Without the
# restore, each iteration would start from the previous iteration's stage-2 statistics.
E_exp2_addend = E_exp2
for macc in range(1,24):
    E_exp2 = E_exp2_addend
    E_fi_2_dict = {}
    if segment_label == 0:
        FnRR = cal_Sn2_swamp(n_fp,mp,macc,sigma)/(n_fp*(sigma**2))
    else:
        # Stage 1: accumulate k1 addends into one segment sum.
        Sn2_swamp = cal_Sn2_swamp(k1,mp,macc,sigma)
        FnRR = Sn2_swamp/(k1*(sigma**2))
        # Stage 2: accumulate the n_fp // k1 segment sums, passed with mp=macc.
        # A separate sigma keeps the addend sigma intact for later iterations.
        # NOTE(review): the n_fp % k1 leftover addends are not modelled.
        sigma_seg = Sn2_swamp ** (1/2)
        E_exp2 = E_i_2exp(sigma_seg)**2
        # Stage 2 runs with a different mp and E_exp2, so stage-1 E_fi_2 memo
        # entries must not be reused.
        E_fi_2_dict = {}
        n_seg = n_fp // k1
        FnRR *= cal_Sn2_swamp(n_seg,macc,macc,sigma_seg)/(n_seg*(sigma_seg**2))
    print(n_fp,macc,math.exp(n_fp*(1-FnRR**(1/2))))